{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from flashbax.vault import Vault\n",
    "import numpy as np\n",
    "import os\n",
    "import torch as th\n",
    "import yaml\n",
    "\n",
    "from components.episode_buffer import EpisodeBatch\n",
    "from functools import partial\n",
    "from envs import REGISTRY as env_REGISTRY\n",
    "from components.transforms import OneHot\n",
    "from components.offline_buffer import DataSaver\n",
    "from types import SimpleNamespace as SN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This notebook converts an og-marl (https://github.com/instadeepai/og-marl) dataset into an h5 dataset suitable for the offpymarl framework.\n",
    "# Use the Google Drive URL (https://drive.google.com/drive/folders/1lw-e5VwIdCtmsGWgQG902yZRArU69TrH)\n",
    "# or follow https://github.com/instadeepai/og-marl/blob/main/examples/download_dataset.py\n",
    "# to download the og-marl dataset.\n",
    "# Create an 'ogmarl_dataset' folder in offpymarl and store the corresponding .vlt dataset there.\n",
    "\n",
    "# Extra package requirements:\n",
    "# jax==0.4.28\n",
    "# flashbax==0.1.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Conversion settings: location of the og-marl vault and which map / quality split to convert.\n",
    "\n",
    "# The vault folder is <parent of the current working directory>/ogmarl_dataset.\n",
    "# os.path.dirname replaces manual splitting on '/', which breaks on platforms with a\n",
    "# different path separator and yields a relative path when the cwd sits directly under '/'.\n",
    "dataset_path = os.path.join(os.path.dirname(os.getcwd()), 'ogmarl_dataset')\n",
    "\n",
    "# Maps an og-marl vault uid to the quality label used in the offpymarl output path.\n",
    "vault_uid2quality = {\n",
    "    \"Good\": \"expert\",\n",
    "    \"Medium\": \"medium\",\n",
    "    \"Poor\": \"poor\"\n",
    "}\n",
    "\n",
    "# You can change the following parameters according to your needs\n",
    "map_name = \"3m\"\n",
    "og_quality = \"Good\"  # must be a key of vault_uid2quality\n",
    "num_traj_per_file = 10000  # passed to DataSaver; the saved run wrote 10000 episodes per .h5 part\n",
    "\n",
    "offpymarl_quality = vault_uid2quality[og_quality]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading vault found at /home/zzq/Project/GitProject/offpymarl/ogmarl_dataset/3m.vlt/Good\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'actions': (1, 996366, 3),\n",
       " 'infos': {'legals': (1, 996366, 3, 9), 'state': (1, 996366, 48)},\n",
       " 'observations': (1, 996366, 3, 30),\n",
       " 'rewards': (1, 996366, 3),\n",
       " 'terminals': (1, 996366, 3),\n",
       " 'truncations': (1, 996366, 3)}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Open the og-marl vault for the selected map / quality and read its stored experience.\n",
    "vlt = Vault(rel_dir=dataset_path, vault_name=f\"{map_name}.vlt\", vault_uid=og_quality)\n",
    "all_data = vlt.read()\n",
    "offline_data = all_data.experience\n",
    "\n",
    "# Display the shape of every array in the experience pytree.\n",
    "# jax.tree_map is deprecated; jax.tree_util.tree_map is available in every JAX version.\n",
    "jax.tree_util.tree_map(lambda x: x.shape, offline_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the SC2 environment config and point it at the map being converted.\n",
    "# yaml.safe_load replaces the bare yaml.load(f): calling yaml.load without a Loader is\n",
    "# deprecated and unsafe (newer PyYAML releases reject it outright), and the safe loader\n",
    "# is enough for a plain key/value config file.\n",
    "with open(\"config/envs/sc2.yaml\", \"r\") as f:\n",
    "    env_config = yaml.safe_load(f)\n",
    "\n",
    "# Expose the top-level config keys as attributes, then override the map name.\n",
    "env_args = SN(**env_config)\n",
    "env_args.env_args['map_name'] = map_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# env_args.env_args"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the environment selected by the config, using its keyword arguments.\n",
    "env_factory = env_REGISTRY[env_args.env]\n",
    "env = env_factory(**env_args.env_args)\n",
    "\n",
    "# Copy every entry of the environment info dict onto env_args as an attribute,\n",
    "# so later cells can read e.g. env_args.n_agents and env_args.n_actions.\n",
    "env_info = env.get_env_info()\n",
    "for info_key, info_value in env_info.items():\n",
    "    setattr(env_args, info_key, info_value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Field layout handed to EpisodeBatch. Fields tagged with the \"agents\" group are sized\n",
    "# by groups[\"agents\"], which is the number of agents reported by the environment.\n",
    "agent_group = \"agents\"\n",
    "scheme = dict(\n",
    "    state=dict(vshape=env_info[\"state_shape\"]),\n",
    "    obs=dict(vshape=env_info[\"obs_shape\"], group=agent_group),\n",
    "    actions=dict(vshape=(1,), group=agent_group, dtype=th.long),\n",
    "    avail_actions=dict(vshape=(env_info[\"n_actions\"],), group=agent_group, dtype=th.int),\n",
    "    reward=dict(vshape=(1,)),\n",
    "    terminated=dict(vshape=(1,), dtype=th.uint8),\n",
    "    corrected_terminated=dict(vshape=(1,), dtype=th.uint8),\n",
    ")\n",
    "groups = dict(agents=env_args.n_agents)\n",
    "\n",
    "# \"actions\" additionally gets a one-hot transform, registered under \"actions_onehot\".\n",
    "onehot_actions = OneHot(out_dim=env_args.n_actions)\n",
    "preprocess = dict(actions=(\"actions_onehot\", [onehot_actions]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'state': {'vshape': 48},\n",
       " 'obs': {'vshape': 30, 'group': 'agents'},\n",
       " 'actions': {'vshape': (1,), 'group': 'agents', 'dtype': torch.int64},\n",
       " 'avail_actions': {'vshape': (9,), 'group': 'agents', 'dtype': torch.int32},\n",
       " 'reward': {'vshape': (1,)},\n",
       " 'terminated': {'vshape': (1,), 'dtype': torch.uint8},\n",
       " 'corrected_terminated': {'vshape': (1,), 'dtype': torch.uint8}}"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the assembled scheme to check the field shapes and dtypes for the selected map.\n",
    "scheme"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Factory that builds a fresh, empty EpisodeBatch on the CPU each time it is called.\n",
    "# NOTE(review): the two positional numbers are presumably batch size (1) and maximum\n",
    "# sequence length (episode_limit + 1) - confirm against EpisodeBatch's signature.\n",
    "episode_limit = env.episode_limit\n",
    "new_batch_fn = partial(\n",
    "    EpisodeBatch,\n",
    "    scheme,\n",
    "    groups,\n",
    "    1,\n",
    "    episode_limit + 1,\n",
    "    preprocess=preprocess,\n",
    "    device=\"cpu\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1, 996366, 3, 9) (1, 996366, 48) (1, 996366) (1, 996366, 3) (1, 996366, 3, 30) (1, 996366)\n"
     ]
    }
   ],
   "source": [
    "# Pull the individual fields out of the vault experience, print their shapes,\n",
    "# then convert each one from a JAX array to a NumPy array.\n",
    "infos = offline_data[\"infos\"]\n",
    "avail_actions = infos[\"legals\"]\n",
    "states = infos[\"state\"]\n",
    "# A step counts as terminated when either the terminal or the truncation flag is set.\n",
    "# Only index 0 of the last axis is kept (presumably the first agent's flag - confirm).\n",
    "terminated = jnp.maximum(offline_data[\"terminals\"], offline_data[\"truncations\"])[..., 0]\n",
    "actions = offline_data[\"actions\"]\n",
    "observations = offline_data[\"observations\"]\n",
    "# Likewise only index 0 of the last reward axis is kept.\n",
    "rewards = offline_data[\"rewards\"][..., 0]\n",
    "\n",
    "field_arrays = (avail_actions, states, terminated, actions, observations, rewards)\n",
    "print(*(arr.shape for arr in field_arrays))\n",
    "\n",
    "# jnp.array -> np.array\n",
    "avail_actions, states, terminated, actions, observations, rewards = (\n",
    "    np.asarray(arr) for arr in field_arrays\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `terminated` was already converted to a NumPy array of shape (1, T) by the previous cell,\n",
    "# so no further conversion is needed here.\n",
    "# np.nonzero returns one index array per axis; axis 1 holds the timestep at which each\n",
    "# episode ends (inclusive), in increasing order.\n",
    "episode_idxs = np.nonzero(terminated)[1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 23%|██▎       | 10000/43559 [00:32<26:12, 21.34it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Save offline buffer to ../ogmarl_dataset/sc2/3m/expert/part_0.h5 with 10000 episodes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 46%|████▌     | 20000/43559 [01:05<15:27, 25.40it/s]  "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Save offline buffer to ../ogmarl_dataset/sc2/3m/expert/part_1.h5 with 10000 episodes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 69%|██████▉   | 30000/43559 [01:38<11:49, 19.11it/s]  "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Save offline buffer to ../ogmarl_dataset/sc2/3m/expert/part_2.h5 with 10000 episodes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 92%|█████████▏| 40000/43559 [02:08<02:10, 27.20it/s]  "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Save offline buffer to ../ogmarl_dataset/sc2/3m/expert/part_3.h5 with 10000 episodes\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 43559/43559 [02:09<00:00, 335.25it/s] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Save offline buffer to ../ogmarl_dataset/sc2/3m/expert/part_4.h5 with 3559 episodes\n"
     ]
    }
   ],
   "source": [
    "# Split the flat vault arrays into episodes and write them out through DataSaver.\n",
    "# Steps after the last terminated index do not form a complete episode and are not written.\n",
    "from tqdm import tqdm\n",
    "\n",
    "save_path = f\"../ogmarl_dataset/sc2/{map_name}/{offpymarl_quality}\"\n",
    "\n",
    "offline_saver = DataSaver(save_path, None, num_traj_per_file)\n",
    "\n",
    "start_idx = 0\n",
    "for end_idx in tqdm(episode_idxs):\n",
    "    # Both bounds are inclusive, so the episode spans end_idx - start_idx + 1 steps.\n",
    "    episode_len = end_idx - start_idx + 1\n",
    "    episode_slice = slice(start_idx, end_idx + 1)\n",
    "    t_slice = slice(0, episode_len)\n",
    "\n",
    "    episode_terminated = terminated[:, episode_slice]\n",
    "    # Notice: unlike the \"episode_runner\", this data has no extra step after the last one.\n",
    "    # For a terminated state s_t, Q(s_t, a_t) would still be updated with Q(s_{t+1}, ...),\n",
    "    # but s_{t+1} is not available here, so the second-to-last step is also marked as terminated.\n",
    "    episode_corrected_terminated = episode_terminated.copy()\n",
    "    # A one-step episode has no second-to-last step; indexing [-2] there would raise\n",
    "    # IndexError, so its flags are left unchanged.\n",
    "    if episode_len >= 2:\n",
    "        episode_corrected_terminated[0, -2] = 1\n",
    "\n",
    "    transition_data = {\n",
    "        \"state\": states[:, episode_slice],\n",
    "        \"obs\": observations[:, episode_slice],\n",
    "        \"actions\": actions[:, episode_slice],\n",
    "        \"avail_actions\": avail_actions[:, episode_slice],\n",
    "        \"reward\": rewards[:, episode_slice],\n",
    "        \"terminated\": episode_terminated,\n",
    "        \"corrected_terminated\": episode_corrected_terminated\n",
    "    }\n",
    "\n",
    "    # Fill the first episode_len timesteps of a fresh batch, then hand a CPU copy\n",
    "    # of every stored field to the saver.\n",
    "    tmp_batch = new_batch_fn()\n",
    "    tmp_batch.update(transition_data, ts=t_slice)\n",
    "    offline_saver.append(data={\n",
    "        k: tmp_batch[k].clone().cpu() for k in tmp_batch.data.transition_data.keys()\n",
    "    })\n",
    "\n",
    "    # The next episode starts right after this one's last step.\n",
    "    start_idx = end_idx + 1\n",
    "\n",
    "offline_saver.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "pymarlx",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
